
# Activate the project environment BEFORE loading packages: `Pkg.activate` only affects
# packages loaded after it runs.
import Pkg
Pkg.activate(".")
using DataStructures, Tables, Plots, CSV, DataFrames, MAT, Distributions
using LinearAlgebra, SparseArrays, JLD2, HDF5, Random, StatsFuns, Pkg
# The `pyplot()` backend used below needs the PyPlot package in the active environment
# (install once with `Pkg.add("PyPlot")`).

# NOTE(review): `outputpath`, `figurepath`, `dataset` and `counterfactuals` are not
# defined in this file; they must already exist as globals when it runs.
# Simulation output keyed by year (2010-2012): one table per year, read from
# `outputpath/simulate<year>_big.csv`. Each row carries an individual index `ind`, a
# `draw` number and the `simulate<year>{a,b}_{placement,enroll}` assignment columns.
dfs = Dict(year => CSV.read(outputpath * "/simulate$(year)_big.csv", DataFrame)
           for year in 2010:2012)

"""
    getFitStats(year, varname, dataset=dataset, counterfactuals=counterfactuals)

Accumulate, for every destination index `j` (indexed like `cf_constants.seats`), the sum
of scores and the count of individuals assigned to `j`. This is done once from the
observed assignment and once from the simulated draws in `dfs[year]`.

`varname` picks the assignment compared:

- `"placement"` uses `cf_constants.placementIndex` (data) and the
  `simulate<year>{a,b}_placement` columns (model).
- `"enroll"` uses `cc.enroll` and the `simulate<year>{a,b}_enroll` columns.

Each individual's score is `sum(100 .* (scores .+ 3)) / 2` over the columns of
`cf_constants.scores`. The observed assignment is counted only on rows with
`draw == 1`. Simulated assignments are counted on every row and in both simulation
columns. Non-positive assignment indices are skipped (presumably "unassigned").

Returns `(sumscore_data, sumscore_model, N_data, N_model)`, each a `Vector{Float64}`
of length `length(cf_constants.seats)`.

Throws `ArgumentError` if `varname` is not `"placement"` or `"enroll"`.
"""
function getFitStats(year, varname, dataset=dataset, counterfactuals=counterfactuals)
    varname in ("placement", "enroll") || throw(ArgumentError(
        "varname must be \"placement\" or \"enroll\", got $(repr(varname))"))
    @info "varname==$varname"
    df = dfs[year]
    cc = dataset[year][1]
    cf_constants = counterfactuals[year][1]
    J = length(cf_constants.seats)
    # Per-individual score: half the row sum of 100 * (score + 3).
    scores = vec(sum(100 .* (cf_constants.scores .+ 3), dims=2) ./ 2)
    # Observed assignment used as the "data" benchmark.
    var = varname == "placement" ? cf_constants.placementIndex : cc.enroll
    colnames = [Symbol("simulate$(year)a_$varname"), Symbol("simulate$(year)b_$varname")]

    sumscore_data = zeros(J)
    N_data = zeros(J)
    sumscore_model = zeros(J)
    N_model = zeros(J)
    for row in eachrow(df)
        sc = scores[row.ind]
        # The observed assignment does not vary across draws: count it on draw 1 only.
        j0 = var[row.ind]
        if row.draw == 1 && j0 > 0
            sumscore_data[j0] += sc
            N_data[j0] += 1
        end
        # Simulated assignments: every draw, both simulation columns.
        for col in colnames
            jj = row[col]
            if jj > 0
                sumscore_model[jj] += sc
                N_model[jj] += 1
            end
        end
    end
    return sumscore_data, sumscore_model, N_data, N_model
end


# Fig 1: observed vs. simulated mean score per destination index, for each year and
# both assignment definitions. Indices on the platform with more than 50 seats are
# drawn in blue; those with more than 100 seats are drawn again in red on top. The
# dash-dot line is the 45-degree line. The plotted values are also written to CSV.
pyplot()  # backend choice is loop-invariant: select it once
for year in 2010:2012, vv in ["placement", "enroll"]
    sumscore_data, sumscore_model, N_data, N_model = getFitStats(year, vv)
    # Indices with a zero count become NaN (0/0).
    meanscore_data = sumscore_data ./ N_data
    meanscore_model = sumscore_model ./ N_model
    cc = dataset[year][1]
    cf_constants = counterfactuals[year][1]
    seats = cf_constants.seats
    over50 = cc.platform .& (seats .> 50)
    over100 = cc.platform .& (seats .> 100)
    keepif = findall(over50)
    keepif2 = findall(over100)

    # NOTE(review): `fillalpha` does not apply to scatter markers; use `markeralpha`
    # if translucent markers were intended.
    plt = scatter(meanscore_data[keepif], meanscore_model[keepif], color=:blue,
                  fillalpha=0.1, xticks=450:50:800, yticks=450:50:800)
    scatter!(plt, meanscore_data[keepif2], meanscore_model[keepif2], color=:red)
    plot!(plt, 450:800, 450:800, color=:black, style=:dashdot, legend=false)
    xlabel!(plt, "Average Scores - Data")
    ylabel!(plt, "Average Scores - Simulation")

    stem = figurepath * "/fit_scores_$(vv)_$year"
    for ext in (".png", ".eps")
        savefig(plt, stem * ext)
    end
    df_plt = DataFrame(meanscore_data=meanscore_data,
                       meanscore_model=meanscore_model,
                       platform=cc.platform,
                       seats=seats,
                       show_blue=cc.platform .& (100 .>= seats .> 50),
                       show_red=over100)
    CSV.write(stem * ".csv", df_plt)
end

#######################
# fig 2: BVP
#######################
function getBVPStats(year,vv="enroll",dataset=dataset,counterfactuals=counterfactuals)
    @assert vv in ["placement","enroll"]
    df = dfs[year]
    cc = dataset[year][1]
    cf_constants = counterfactuals[year][1]
    var = vv=="placement" ? cf_constants.placementIndex : cc.enroll
    teaching = cf_constants.typeIndex[:,1] .== 7
    J = length(cf_constants.seats)
    J_on = findall(cc.platform)
    inds_on = cf_constants.typeIndex[J_on,4]
    scores =  (vec(sum(100 .* (cf_constants.scores .+ 3),dims=2)./2))
    score_round = Int.( 10 .* round.(  (5 .+ scores)./10) ) .- 5
    binmeans = 455:10:805
    bins = OrderedDict([v=>k for (k,v) in enumerate(binmeans)]...)
    #data stuff
    nteach = zeros(Int,length(bins))
    nn = zeros(Int,length(bins))
    for ii=1:length(cc.enroll)
        (450 <= score_round[ii] <= 805) || continue
        bin = bins[score_round[ii]]
        if var[ii] > 0
            nn[bin] += 1
            nteach[bin] += teaching[var[ii]]
        end
    end
    #model stuff
    nteach_m = zeros(Int,length(bins))
    nn_m = zeros(Int,length(bins))
    colnames = [Symbol("simulate$(year)a_$vv"),Symbol("simulate$(year)b_$vv")]
    for row in eachrow(df)
        (450 <= score_round[row.ind] <= 805) || continue
        bin = bins[score_round[row.ind]]
        for col in colnames
            if row[col] > 0
                nn_m[bin] += 1
                nteach_m[bin] += teaching[row[col]]
            end
        end
    end
    return nteach./nn, nteach_m./nn_m, binmeans
end


# Fig 2 (BVP): share assigned to teaching by score bin, 2010 (blue) vs. 2011 (red),
# for enrollment and placement. The 2011 curves (and the 2011 - 2010 differences) are
# drawn as separate segments below 500, between 500 and 600, and above 600. As a
# result, lines are not joined across those score cutoffs.
pyplot()  # backend choice is loop-invariant: select it once
for vv in ["enroll", "placement"]
    data2010, model2010, inds = getBVPStats(2010, vv)
    data2011, model2011, _ = getBVPStats(2011, vv)
    segments = (findall(inds .< 500), findall(500 .< inds .< 600), findall(600 .< inds))
    innertxt = vv == "enroll" ? "Enrolling in" : "Placing in"

    # Data only (solid): 2010 as one line, 2011 by segment.
    pltBVP = plot(inds, data2010, color=:blue)
    for seg in segments
        plot!(pltBVP, inds[seg], data2011[seg], color=:red)
    end
    plot!(pltBVP, legend=false)
    # NOTE(review): this text describes the y-values while the x-axis holds the score
    # bins; `ylabel!`/`title!` may have been intended.
    xlabel!(pltBVP, "Probability of $(innertxt) Teaching: 2010 vs. 2011")
    for ext in (".png", ".eps")
        savefig(pltBVP, figurepath * "/BVP_dataonly_$(vv)" * ext)
    end

    # Overlay the model predictions (dotted) on the same figure.
    plot!(pltBVP, inds, model2010, linestyle=:dot, color=:blue)
    for seg in segments
        plot!(pltBVP, inds[seg], model2011[seg], linestyle=:dot, color=:red)
    end
    plot!(pltBVP, legend=false)
    xlabel!(pltBVP, "Probability of $(innertxt) Teaching: Data (solid) vs. Model (dot)")
    for ext in (".png", ".eps")
        savefig(pltBVP, figurepath * "/fit_BVP_$(vv)" * ext)
    end
    df_bvp = DataFrame(score_mean_bin=inds, prTeach2010_data=data2010,
                       prTeach2010_model=model2010, prTeach2011_data=data2011,
                       prTeach2011_model=model2011)
    CSV.write(figurepath * "/fit_BVP_$(vv).csv", df_bvp)

    # "Treatment effect": 2011 - 2010, data (solid blue) then model (dotted red).
    diff_data = data2011 .- data2010
    diff_model = model2011 .- model2010
    pltBVPdif = plot(legend=false)
    for seg in segments
        plot!(pltBVPdif, inds[seg], diff_data[seg], color=:blue)
    end
    for seg in segments
        plot!(pltBVPdif, inds[seg], diff_model[seg], linestyle=:dot, color=:red)
    end
    xlabel!(pltBVPdif, "Probability of $(innertxt) Teaching: 2011 - 2010")
    for ext in (".png", ".eps")
        savefig(pltBVPdif, figurepath * "/fit_BVP_tfx_$(vv)" * ext)
    end
end
